import json
import torch
from torch.utils.data import Dataset
import random
from sentence_transformers import SentenceTransformer

# Seed Python's global `random` generator at import time. TextPairDataset
# draws from this same global generator (random.sample / random.shuffle), so
# this fixed seed is what makes its sampling and ordering repeatable.
# NOTE(review): this is a module-level side effect that also affects any
# other code using the global `random` module after this import.
random.seed(42)

class TextPairDataset(Dataset):
    """Flat, shuffled dataset of token-id samples paired with a reason.

    The JSON file is expected to contain a list of items. For each item the
    constructor reads:

    * ``item["reason"]`` (required) -- stored unchanged on every sample
      produced from that item.
    * ``item["pos_token_texts_list"]`` (optional) -- an iterable of token
      sequences; each one yields a sample with ``label == 1`` and
      ``sample_type == "positive"``.
    * ``item["negatives"][key]`` (optional) for every configured key -- a
      non-empty token sequence yields a sample with ``label == 0`` and
      ``sample_type == key``.

    Token sequences are converted with ``tokenizer.convert_tokens_to_ids``
    (presumably they are lists of token strings -- confirm against the data
    producer).

    Args:
        json_file: Path of the UTF-8 JSON file to load.
        tokenizer: Object exposing ``convert_tokens_to_ids``. Ignored when
            ``model_name`` is given.
        model_name: If truthy, a ``SentenceTransformer`` is loaded under this
            name and its ``tokenizer`` attribute is used.
        limit: If greater than zero, keep at most this many randomly chosen
            samples. ``0`` (the default), negative values and ``None`` keep
            everything.
        negative_types: Iterable of keys to read from each item's
            ``"negatives"`` mapping. Defaults to ``DEFAULT_NEGATIVE_TYPES``.

    Raises:
        ValueError: If no tokenizer is available, i.e. neither ``tokenizer``
            nor ``model_name`` was supplied.
        KeyError: If an item has no ``"reason"`` entry.

    Attributes:
        model: The loaded ``SentenceTransformer``, or ``None`` when a
            tokenizer was passed in directly.
        tokenizer: The tokenizer used to convert tokens to ids.
        negative_types_keys: List of negative-sample keys in use.
        samples: List of dicts with keys ``token_ids``, ``reason``,
            ``label`` and ``sample_type``.
    """

    # Keys looked up under each item's "negatives" mapping when the caller
    # does not pass ``negative_types``.
    DEFAULT_NEGATIVE_TYPES = (
        "neg_type_1_tokens",
        "neg_type_2_tokens",
        "neg_type_3_tokens",
        "neg_type_4_tokens",
    )

    def __init__(self, json_file, tokenizer=None, model_name=None, limit=0, negative_types=None):
        super().__init__()

        if model_name:
            self.model = SentenceTransformer(model_name)
            self.tokenizer = self.model.tokenizer
        else:
            # Always define the attribute so callers can test `ds.model is None`
            # instead of hitting AttributeError.
            self.model = None
            self.tokenizer = tokenizer

        # Fail early with a clear message rather than with an AttributeError
        # on None deep inside the loading loop.
        if self.tokenizer is None:
            raise ValueError(
                "TextPairDataset requires either `tokenizer` or `model_name`."
            )

        # Snapshot into a list: the keys are iterated once per item, so a
        # one-shot iterable (e.g. a generator) must not be stored as-is.
        if negative_types is None:
            self.negative_types_keys = list(self.DEFAULT_NEGATIVE_TYPES)
        else:
            self.negative_types_keys = list(negative_types)

        with open(json_file, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        to_ids = self.tokenizer.convert_tokens_to_ids
        self.samples = []
        for item in raw_data:
            reason = item["reason"]

            # Every positive token sequence becomes its own sample.
            for pos_tokens in item.get("pos_token_texts_list") or []:
                self.samples.append({
                    "token_ids": to_ids(pos_tokens),
                    "reason": reason,
                    "label": 1,
                    "sample_type": "positive",
                })

            # `or {}` also covers an explicit JSON null for "negatives".
            negatives = item.get("negatives") or {}
            for neg_type_key in self.negative_types_keys:
                neg_tokens = negatives.get(neg_type_key, [])
                if not neg_tokens:
                    # Missing or empty negatives produce no sample.
                    continue
                self.samples.append({
                    "token_ids": to_ids(neg_tokens),
                    "reason": reason,
                    "label": 0,
                    "sample_type": neg_type_key,
                })

        # Uses the global `random` generator; the sample-then-shuffle call
        # order is kept so results under a fixed seed stay the same.
        if limit and limit > 0:
            self.samples = random.sample(self.samples, min(limit, len(self.samples)))

        random.shuffle(self.samples)

    def __len__(self) -> int:
        """Return the number of samples kept after the optional limit."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(token_ids, reason, label, sample_type)`` for sample ``idx``."""
        sample = self.samples[idx]
        return sample["token_ids"], sample["reason"], sample["label"], sample["sample_type"]